/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.lvyh.lightframe.transaction.core.repository.impl;

import com.alibaba.druid.pool.DruidDataSource;
import com.lvyh.lightframe.transaction.common.config.TransactionApplicationConfig;
import com.lvyh.lightframe.transaction.common.config.repository.JdbcRepositoryConfig;
import com.lvyh.lightframe.transaction.common.constant.TransactionConstant;
import com.lvyh.lightframe.transaction.common.domain.Participant;
import com.lvyh.lightframe.transaction.common.domain.Transaction;
import com.lvyh.lightframe.transaction.common.enums.TransactionStatus;
import com.lvyh.lightframe.transaction.core.ext.Spi;
import com.lvyh.lightframe.transaction.core.repository.TransactionLogRepository;
import com.lvyh.lightframe.transaction.core.serialize.Serializer;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Timestamp;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.stream.Collectors;

/**
 * JDBC transaction log storage support
 */
@Slf4j
@Spi("jdbcrepository")
public class JdbcTransactionRepository implements TransactionLogRepository {

    private Serializer serializer;
    private DruidDataSource dataSource;
    private String tableName;

    @Override
    public void setObjectSerializer(Serializer serializer) {
        this.serializer = serializer;
    }

    @Override
    public String getRepositoryName() {
        return TransactionConstant.TRANSACTION_REPOSITORY_JDBC;
    }

    @Override
    public void init(TransactionApplicationConfig transactionApplicationConfig) {
        dataSource = new DruidDataSource();
        JdbcRepositoryConfig jdbcConfig = (JdbcRepositoryConfig) transactionApplicationConfig.getRepositoryConfig();
        BeanUtils.copyProperties(jdbcConfig, dataSource);
        tableName = TransactionConstant.TRANSACTION_LOG_PREFIX + "_" + jdbcConfig.getTableSuffix().replaceAll("-", "_");
        //Create table
        executeUpdate(buildCreateTableSql(tableName));
    }

    @Override
    public Integer insert(Transaction transaction) {
        StringBuilder sql = new StringBuilder()
                .append("insert into ")
                .append(tableName)
                .append("(transaction_id,target_class,target_method,retried_count,send_message_count,create_time,last_time,status,invocation,role,error_message)")
                .append(" values(?,?,?,?,?,?,?,?,?,?,?)");
        try {
            //Serialize the set of transaction participants into binary data
            byte[] participantSerialize = serializer.serialize(transaction.getParticipants());
            return executeUpdate(sql.toString(),
                    transaction.getTransactionId(),
                    transaction.getTargetClass(),
                    transaction.getTargetMethod(),
                    transaction.getRetriedCount(),
                    transaction.getSendMessageCount(),
                    transaction.getCreateTime(),
                    transaction.getLastTime(),
                    transaction.getStatus(),
                    participantSerialize,
                    transaction.getRole(),
                    transaction.getErrorMessage());
        } catch (Exception e) {
            e.printStackTrace();
            return TransactionConstant.JDBC_ERROR;
        }
    }

    @Override
    public Integer update(Transaction transaction) {
        transaction.setLastTime(new Date());
        String sql = "update " + tableName + " set last_time = ?,retried_count = ?,send_message_count = ?,invocation = ?,status = ?,error_message = ? where transaction_id = ?";
        try {
            byte[] participantSerialize = serializer.serialize(transaction.getParticipants());
            return executeUpdate(sql,
                    transaction.getLastTime(),
                    transaction.getRetriedCount(),
                    transaction.getSendMessageCount(),
                    participantSerialize,
                    transaction.getStatus(),
                    transaction.getErrorMessage(),
                    transaction.getTransactionId());
        } catch (Exception e) {
            e.printStackTrace();
            return TransactionConstant.JDBC_ERROR;
        }
    }

    @Override
    public Transaction getById(String transactionId) {
        String selectSql = "select * from " + tableName + " where transaction_id=?";
        List<Map<String, Object>> list = executeQuery(selectSql, transactionId);
        if (Objects.nonNull(list) && list.size() > 0) {
            return list.stream().filter(Objects::nonNull).map(this::buildByResultMap).findFirst().get();
        }
        return null;
    }

    @Override
    public List<Transaction> findRecover(Date date, Integer retriedPeriod) {
        String selectSql = "select * from " + tableName + " where last_time < ?  and status = " + TransactionStatus.FAILURE.getCode() + " and send_message_count < " + retriedPeriod;
        List<Map<String, Object>> list = executeQuery(selectSql, date);
        if (Objects.nonNull(list) && list.size() > 0) {
            return list.stream().filter(Objects::nonNull).map(this::buildByResultMap).collect(Collectors.toList());
        }
        return null;
    }

    private String buildCreateTableSql(String tableName) {
        StringBuilder sql = new StringBuilder();
        sql.append("CREATE TABLE IF NOT EXISTS `")
                .append(tableName).append("` (\n")
                .append("  `transaction_id` varchar(64) NOT NULL COMMENT '全局事务id',\n")
                .append("  `target_class` varchar(256) DEFAULT NULL COMMENT '当前事务执行服务方法所在类',\n")
                .append("  `target_method` varchar(128) DEFAULT NULL COMMENT '当前事务执行服务方法',\n")
                .append("  `retried_count` int(3) DEFAULT NULL COMMENT '重试次数',\n")
                .append("  `send_message_count` int(3) DEFAULT NULL COMMENT '发送事务消息次数',\n")
                .append("  `create_time` datetime DEFAULT NULL COMMENT '事务日志创建时间',\n")
                .append("  `last_time` datetime DEFAULT NULL COMMENT '上次修改时间',\n")
                .append("  `status` int(2) DEFAULT NULL COMMENT '事务状态，0-回滚，1-开始，2-预提交，3-提交',\n")
                .append("  `invocation` longblob COMMENT '当前事务关联的事务参与者信息序列化字节',\n")
                .append("  `role` int(2) DEFAULT NULL COMMENT '事务角色',\n")
                .append("  `error_message` text COMMENT '服务调用失败信息',\n")
                .append("   PRIMARY KEY (`transaction_id`)\n")
                .append(") ENGINE=InnoDB DEFAULT CHARSET=utf8 COMMENT='本地事务日志表'; ");
        return sql.toString();
    }

    /**
     * Convert query results to transaction objects
     */
    private Transaction buildByResultMap(final Map<String, Object> map) {
        Transaction transaction = new Transaction();
        transaction.setTransactionId((String) map.get("transaction_id"));
        transaction.setRetriedCount((Integer) map.get("retried_count"));
        transaction.setSendMessageCount((Integer) map.get("send_message_count"));
        transaction.setCreateTime((Date) map.get("create_time"));
        transaction.setLastTime((Date) map.get("last_time"));
        transaction.setStatus((Integer) map.get("status"));
        transaction.setRole((Integer) map.get("role"));
        transaction.setErrorMessage((String) map.get("error_message"));
        transaction.setTargetClass((String) map.get("target_class"));
        transaction.setTargetMethod((String) map.get("target_method"));
        byte[] bytes = (byte[]) map.get("invocation");
        try {
            List<Participant> participants = serializer.deserialize(bytes, CopyOnWriteArrayList.class);
            transaction.setParticipants(participants);
        } catch (Exception e) {
            e.printStackTrace();
        }
        return transaction;
    }

    private int executeUpdate(final String sql, final Object... params) {
        try (Connection connection = dataSource.getConnection();
             PreparedStatement ps = connection.prepareStatement(sql)) {
            if (params != null) {
                for (int i = 0; i < params.length; i++) {
                    ps.setObject(i + 1, params[i]);
                }
            }
            return ps.executeUpdate();
        } catch (SQLException e) {
            log.error("executeUpdate-> " + e.getMessage());
        }
        return 0;
    }

    private List<Map<String, Object>> executeQuery(final String sql, final Object... params) {
        List<Map<String, Object>> list = null;
        try (Connection connection = dataSource.getConnection();
             PreparedStatement ps = connection.prepareStatement(sql)) {
            if (params != null) {
                for (int i = 0; i < params.length; i++) {
                    ps.setObject(i + 1, params[i]);
                }
            }
            try (ResultSet rs = ps.executeQuery()) {
                ResultSetMetaData md = rs.getMetaData();
                int columnCount = md.getColumnCount();
                list = new ArrayList<>();
                while (rs.next()) {
                    Map<String, Object> rowData = new HashMap<>(16);
                    for (int i = 1; i <= columnCount; i++) {
                        rowData.put(md.getColumnName(i), rs.getObject(i));
                    }
                    list.add(rowData);
                }
            }
        } catch (SQLException e) {
            log.error("executeQuery-> " + e.getMessage());
        }
        return list;
    }
}
